import pandas as pd
import numpy as np
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score
import warnings
# Silence whole warning categories so the script's console output is not
# flooded while the tables are generated.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=SyntaxWarning)
# Opt in to pandas' future "no silent downcasting" behaviour.
pd.set_option('future.no_silent_downcasting', True)
# Disable pandas' chained-assignment check; the sections below add columns
# to frames that were obtained by filtering `docs`.
pd.options.mode.chained_assignment = None 

# Test-set predictions. The code below reads the columns 'hypothesis',
# 'premise', 'augmented_hypothesis', 'entailment' and one column per entry
# in `models`.
docs = pd.read_csv('./data/polnli_test_results.csv')

def drop_rare(df, column='target', min_count=10):
    """
    Keep only the rows whose value in `column` occurs at least `min_count` times.

    Parameters:
    df (pd.DataFrame): The input dataframe. It is not modified.
    column (str): The column whose value frequencies are checked.
    min_count (int): Minimum number of occurrences a value needs for its rows to be kept.

    Returns:
    pd.DataFrame: The surviving rows, in their original order, with a fresh 0..n-1 index.
    """
    frequencies = df[column].value_counts()
    # Values that occur often enough to be kept.
    frequent_values = frequencies.index[frequencies.to_numpy() >= min_count]
    kept_rows = df.loc[df[column].isin(frequent_values)]
    return kept_rows.reset_index(drop=True)

def bias_metrics(df, bootstrap=True, n_iterations=1000, random_state=None):
    """
    Computes bias metrics for each model (row) across groups (columns):
    - Min, mean and max F1 scores
    - Sample standard deviation of the F1 scores
    - Disparity ratio (max/min F1) with an optional bootstrapped standard error

    Parameters:
        df : pandas.DataFrame
            DataFrame where rows are models and columns are groups containing F1 scores.
        bootstrap : bool, default True
            Whether to compute bootstrapped standard errors for the disparity ratio.
        n_iterations : int, default 1000
            Number of bootstrap iterations if bootstrap=True.
        random_state : int or None, default None
            Random seed for reproducibility. When given, a private generator is
            seeded with it, so NumPy's global random state is left untouched.
            When None, NumPy's global random state is used.

    Returns:
        result_df : pandas.DataFrame
            DataFrame indexed by model name (index name 'model') with columns, in order:
                - 'min_f1': Minimum F1 score for each model.
                - 'mean_f1': Mean F1 score for each model.
                - 'max_f1': Maximum F1 score for each model.
                - 'std deviation': Sample standard deviation (ddof=1) of the F1 scores.
                - 'disparity_ratio': max(F1) / min(F1); NaN when min(F1) is 0.
                - 'std_error': (Only if bootstrap=True) Bootstrapped standard error
                  of the disparity ratio; resamples whose minimum is 0 are skipped.
    """
    # A private RandomState yields the same draws that seeding the global
    # generator would, without the side effect of reseeding it for every caller.
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    results = []

    for model_name, row in df.iterrows():
        values = row.values.astype(float)

        min_f1 = np.min(values)
        max_f1 = np.max(values)
        mean_f1 = np.mean(values)
        stdev = np.std(values, ddof=1)  # Sample standard deviation

        disparity_ratio = _max_min_ratio(values)

        std_error = np.nan
        if bootstrap:
            # Resample the groups with replacement and recompute the ratio.
            boot_ratios = np.array([
                _max_min_ratio(rng.choice(values, size=len(values), replace=True))
                for _ in range(n_iterations)
            ], dtype=float)
            # Ratios that hit a zero minimum are undefined and are left out.
            boot_ratios = boot_ratios[~np.isnan(boot_ratios)]
            # ddof=1 needs at least two ratios; otherwise the error is undefined.
            if boot_ratios.size > 1:
                std_error = np.std(boot_ratios, ddof=1)

        results.append((model_name, min_f1, mean_f1, max_f1, stdev, disparity_ratio, std_error))

    # Create the result DataFrame
    result_df = pd.DataFrame(
        results,
        columns=['model', 'min_f1', 'mean_f1', 'max_f1', 'std deviation', 'disparity_ratio', 'std_error'],
    )
    result_df.set_index('model', inplace=True)

    if not bootstrap:
        result_df.drop(columns=['std_error'], inplace=True)

    return result_df


def _max_min_ratio(values):
    """Return max(values) / min(values), or NaN when the minimum is 0."""
    lowest = np.min(values)
    # Avoid division by zero
    return np.nan if lowest == 0 else np.max(values) / lowest

# Names of the prediction columns in `docs`, one per evaluated model.
# The same names label the rows of the result tables.
models = [
    'base_nli',
    'large_nli',
    'base_debate',
    'large_debate',
    'base_modern',
    'large_modern',
    'llama',
    'llama70b',
    'sonnet',
]

# Hypotheses about social groups whose documents are evaluated.
# The strings are compared verbatim against the 'hypothesis' column,
# so their spelling must not be altered.
group_hypoths = [
    # "attacking" hypotheses
    'This text is attacking the hearing impaired.',
    'This text is attacking people for their sexuality.',
    'This text is attacking transgender people.',
    'This text is attacking people for their race.',
    'This text is attacking people for their gender.',
    'This text is attacking the disabled.',
    'This text is attacking political outgroups.',
    'This text is attacking people for their place of origin.',
    'This text is attacking muslims.',
    'This text is attacking people for their religion.',
    'This text is attacking women.',
    'This text is attacking blacks.',
    'This text is attacking men.',
    'This text is attacking undocumented people.',
    'This text is attacking gay people.',
    'This text is attacking immigrants.',
    'This text is attacking bisexual people.',
    'This text is attacking hindus.',
    'This text is attacking latinos.',
    'This text is attacking straight people.',
    'This text is attacking christians.',
    'This text is attacking people for their age.',
    'This text is attacking whites.',
    'This text is attacking the visually impaired.',
    'This text is attacking the cognitively disabled.',
    'This text is attacking lesbians.',
    'This text is attacking young adults.',
    'This text is attacking asians.',
    'This text is attacking transgender women.',
    'This text is attacking jews.',
    'This text is attacking middle easterners.',
    'This text is attacking pacific islanders.',
    # "dehumanizing" hypotheses
    'This text is dehumanizing lesbians.',
    'This text is dehumanizing muslims.',
    'This text is dehumanizing people for their sexuality.',
    'This text is dehumanizing people for their race.',
    'This text is dehumanizing jews.',
    'This text is dehumanizing people for their place of origin.',
    'This text is dehumanizing gay people.',
    'This text is dehumanizing non binary people.',
    'This text is dehumanizing middle easterners.',
    'This text is dehumanizing women.',
    'This text is dehumanizing latinos.',
    'This text is dehumanizing people for their religion.',
    'This text is dehumanizing people for their gender.',
    'This text is dehumanizing blacks.',
    'This text is dehumanizing the disabled.',
    'This text is dehumanizing christians.',
    'This text is dehumanizing whites.',
    'This text is dehumanizing transgender women.',
    'This text is dehumanizing children.',
    'This text is dehumanizing bisexual people.',
    'This text is dehumanizing asians.',
    'This text is dehumanizing straight people.',
    'This text is dehumanizing the neurologically disables.',
    'This text is dehumanizing immigrants.',
    'This text is dehumanizing undocumented people.',
    'This text is dehumanizing pacific islanders.',
    # "defending" hypotheses
    'This text is defending latinos.',
    'This text is defending the visually impaired.',
    'This text is defending people for their place of origin.',
    'This text is defending people for their sexuality.',
    'This text is defending people for their race.',
    'This text is defending lesbians.',
    'This text is defending blacks.',
    'This text is defending hindus.',
    'This text is defending people for their religion.',
    'This text is defending women.',
    'This text is defending people for their gender.',
    'This text is defending native americans.',
    'This text is defending christians.',
    'This text is defending gay people.',
    'This text is defending muslims.',
    'This text is defending bisexual people.',
    'This text is defending the disabled.',
    'This text is defending men.',
    'This text is defending middle easterners.',
    'This text is defending undocumented people.',
    'This text is defending whites.',
    'This text is defending people for their age.',
    'This text is defending the cognitively disabled.',
    'This text is defending immigrants.',
    'This text is defending transgender women.',
    'This text is defending pacific islanders.',
    'This text is defending migrant workers.',
    'This text is defending straight people.',
    'This text is defending asians.',
    'This text is defending transgender men.',
    'This text is defending transgender people.',
    'This text is defending jews.',
    'This text is defending mormons.',
    # "advocates for violence against" hypotheses
    'This text advocates for violence against jews.',
    'This text advocates for violence against bisexual people.',
    'This text advocates for violence against men.',
    'This text advocates for violence against pacific islanders.',
    'This text advocates for violence against the middle aged.',
    'This text advocates for violence against people for their place of origin.',
    'This text advocates for violence against lesbians.',
    'This text advocates for violence against native americans.',
    'This text advocates for violence against people for their religion.',
    'This text advocates for violence against people for their sexuality.',
    'This text advocates for violence against straight people.',
    'This text advocates for violence against people for their gender.',
    'This text advocates for violence against women.',
    'This text advocates for violence against middle easterners.',
    'This text advocates for violence against blacks.',
    'This text advocates for violence against undocumented people.',
    'This text advocates for violence against christians.',
    'This text advocates for violence against people for their race.',
    'This text advocates for violence against migrant workers.',
    'This text advocates for violence against gay people.',
    'This text advocates for violence against muslims.',
    'This text advocates for violence against immigrants.',
    'This text advocates for violence against atheists.',
    'This text advocates for violence against whites.',
    'This text advocates for violence against latinos.',
    'This text advocates for violence against transgender women.',
    'This text advocates for violence against the disabled.',
    'This text advocates for violence against non binary people.',
    'This text advocates for violence against children.',
]

############
## Table 13
############
# Subset data to relevant documents: keep the group hypotheses and drop any
# hypothesis with fewer than 30 documents. reset_index() returns a fresh frame,
# so the column assignments below do not write into a slice of `docs`.
groups = drop_rare(
    docs[docs['hypothesis'].isin(group_hypoths)],
    column='hypothesis',
    min_count=30,
).reset_index(drop=True)

# Premise / hypothesis pairs, one dict per retained document.
groups_dict = [
    {'text': premise, 'text_pair': hypothesis}
    for premise, hypothesis in zip(groups['premise'], groups['augmented_hypothesis'])
]

# Clean group labels: strip the hypothesis template so only the group name
# remains. Raw strings keep the regex escapes valid Python string literals.
groups['group'] = groups['hypothesis'].replace(
    r'This text is attacking |This text is dehumanizing |This text is defending |This text advocates for violence against ',
    '',
    regex=True,
)
groups['group'] = groups['group'].replace(r'people for their ', '', regex=True)
groups['group'] = groups['group'].replace(r'\.', '', regex=True)

# Calculate bias scores: one F1 per (model, group).
# Split the frame once; sort=False keeps groups in order of first appearance.
docs_by_group = list(groups.groupby('group', sort=False))
group_res_dicts = []
for model in models:
    model_f1 = {}
    for group, group_docs in docs_by_group:
        model_f1[group] = f1_score(group_docs['entailment'], group_docs[model])
    group_res_dicts.append(model_f1)
group_res_df = pd.DataFrame(group_res_dicts, index=models)
# NOTE: no random_state is passed, so the bootstrapped std_error column varies
# slightly from run to run.
group_bias = bias_metrics(group_res_df*100).round(2)

group_string = group_bias.to_string()

output_file = './tables/table_13.txt'
with open(output_file, 'w', encoding='utf-8') as f:
    f.write(group_string)

############
## Table 14
############
# Political hypotheses whose documents are evaluated for Table 14.
# The strings are compared verbatim against the 'hypothesis' column.
pol = [
    # "opposes <target>"
    'The author of this text opposes pompeo.',
    'The author of this text opposes democrats.',
    'The author of this text opposes jim jordan.',
    'The author of this text opposes gaetz.',
    'The author of this text opposes raskin.',
    'The author of this text opposes clinton.',
    'The author of this text opposes schiff.',
    'The author of this text opposes biden.',
    'The author of this text opposes romney.',
    'The author of this text opposes trump.',
    'The author of this text opposes paul ryan.',
    'The author of this text opposes kinzinger.',
    'The author of this text opposes manchin.',
    'The author of this text opposes cheney.',
    'The author of this text opposes mike pence.',
    'The author of this text opposes conservatives.',
    'The author of this text opposes pelosi.',
    'The author of this text opposes mcconnell.',
    'The author of this text opposes cruz.',
    'The author of this text opposes mccarthy.',
    'The author of this text opposes boebert.',
    'The author of this text opposes murkowski.',
    'The author of this text opposes desantis.',
    'The author of this text opposes gosar.',
    'The author of this text opposes bernie sanders.',
    'The author of this text opposes cawthorn.',
    'The author of this text opposes ocasio-cortez.',
    'The author of this text opposes schumer.',
    'The author of this text opposes susan collins.',
    'The author of this text opposes pressley.',
    'The author of this text opposes republicans.',
    'The author of this text opposes kamala harris.',
    'The author of this text opposes sinema.',
    'The author of this text opposes hakeem jeffries.',
    'The author of this text opposes liberals.',
    'The author of this text opposes tlaib.',
    'The author of this text opposes marjorie taylor greene.',
    'The author of this text opposes ilhan omar.',
    # "supports <target>"
    'The author of this text supports hakeem jeffries.',
    'The author of this text supports democrats.',
    'The author of this text supports paul ryan.',
    'The author of this text supports schiff.',
    'The author of this text supports biden.',
    'The author of this text supports kamala harris.',
    'The author of this text supports clinton.',
    'The author of this text supports trump.',
    'The author of this text supports gaetz.',
    'The author of this text supports susan collins.',
    'The author of this text supports boebert.',
    'The author of this text supports romney.',
    'The author of this text supports pressley.',
    'The author of this text supports sinema.',
    'The author of this text supports mccarthy.',
    'The author of this text supports republicans.',
    'The author of this text supports cheney.',
    'The author of this text supports pelosi.',
    'The author of this text supports marjorie taylor greene.',
    'The author of this text supports tlaib.',
    'The author of this text supports cruz.',
    'The author of this text supports mike pence.',
    'The author of this text supports mcconnell.',
    'The author of this text supports conservatives.',
    'The author of this text supports schumer.',
    'The author of this text supports raskin.',
    'The author of this text supports ocasio-cortez.',
    'The author of this text supports manchin.',
    'The author of this text supports kinzinger.',
    'The author of this text supports pompeo.',
    'The author of this text supports murkowski.',
    'The author of this text supports desantis.',
    'The author of this text supports cawthorn.',
    'The author of this text supports jim jordan.',
    'The author of this text supports gosar.',
    'The author of this text supports cori bush.',
    'The author of this text supports liberals.',
    'The author of this text supports ilhan omar.',
    'The author of this text supports bernie sanders.',
    # "is neutral towards <target>"
    'The author of this text is neutral towards raskin.',
    'The author of this text is neutral towards kinzinger.',
    'The author of this text is neutral towards hakeem jeffries.',
    'The author of this text is neutral towards democrats.',
    'The author of this text is neutral towards cruz.',
    'The author of this text is neutral towards murkowski.',
    'The author of this text is neutral towards trump.',
    'The author of this text is neutral towards boebert.',
    'The author of this text is neutral towards biden.',
    'The author of this text is neutral towards clinton.',
    'The author of this text is neutral towards sinema.',
    'The author of this text is neutral towards schiff.',
    'The author of this text is neutral towards paul ryan.',
    'The author of this text is neutral towards mike pence.',
    'The author of this text is neutral towards romney.',
    'The author of this text is neutral towards republicans.',
    'The author of this text is neutral towards cawthorn.',
    'The author of this text is neutral towards mccarthy.',
    'The author of this text is neutral towards mcconnell.',
    'The author of this text is neutral towards ilhan omar.',
    'The author of this text is neutral towards pompeo.',
    'The author of this text is neutral towards pelosi.',
    'The author of this text is neutral towards bernie sanders.',
    'The author of this text is neutral towards tlaib.',
    'The author of this text is neutral towards kamala harris.',
    'The author of this text is neutral towards gosar.',
    'The author of this text is neutral towards liberals.',
    'The author of this text is neutral towards conservatives.',
    'The author of this text is neutral towards marjorie taylor greene.',
    'The author of this text is neutral towards manchin.',
    'The author of this text is neutral towards schumer.',
    'The author of this text is neutral towards susan collins.',
    'The author of this text is neutral towards pressley.',
    'The author of this text is neutral towards taylor greene.',
    'The author of this text is neutral towards jim jordan.',
    'The author of this text is neutral towards ocasio-cortez.',
    'The author of this text is neutral towards desantis.',
    'The author of this text is neutral towards gaetz.',
    # policy / topic hypotheses
    'The author of this text believes in the use of affirmative action.',
    'The author of this text does not believe that democratic governments should require voters to present photo identification at the polling station.',
    'The author of this text supports passing the American Jobs Act.',
    'The author of this text believes we should subsidize student loans.',
    'The author of this text supports limiting the right to bear arms.',
    'The author of this text supports voter IDs.',
    'The author of this text does not believe that Israel should lift the blockade of Gaza.',
    'The author of this text believes we should end affirmative action.',
    'The author of this text believes we should adopt an austerity regime.',
    'The author of this text believes we should adopt libertarianism.',
]

# Names of the political figures and parties that appear as hypothesis targets.
targets = [
    'pompeo',
    'democrats',
    'jim jordan',
    'gaetz',
    'hakeem jeffries',
    'raskin',
    'paul ryan',
    'schiff',
    'kinzinger',
    'clinton',
    'biden',
    'romney',
    'cruz',
    'murkowski',
    'kamala harris',
    'trump',
    'boebert',
    'manchin',
    'cheney',
    'susan collins',
    'mike pence',
    'conservatives',
    'sinema',
    'pressley',
    'republicans',
    'pelosi',
    'cawthorn',
    'mccarthy',
    'mcconnell',
    'ilhan omar',
    'marjorie taylor greene',
    'tlaib',
    'bernie sanders',
    'desantis',
    'gosar',
    'liberals',
    'ocasio-cortez',
    'schumer',
    'taylor greene',
    'cori bush',
]

# Targets and topic stances treated as liberal.
lib = [
    # people and parties
    'democrats',
    'hakeem jeffries',
    'raskin',
    'schiff',
    'clinton',
    'biden',
    'kamala harris',
    'manchin',
    'pressley',
    'pelosi',
    'ilhan omar',
    'tlaib',
    'bernie sanders',
    'liberals',
    'ocasio-cortez',
    'schumer',
    'cori bush',
    'sinema',
    # topic stances (hypothesis text without the "The author of this text " prefix)
    'believes in the use of affirmative action.',
    'does not believe that democratic governments should require voters to present photo identification at the polling station.',
    'supports passing the American Jobs Act.',
    'believes we should subsidize student loans.',
    'supports limiting the right to bear arms.',
]


# Targets and topic stances treated as conservative.
con = [
    # people and parties
    'pompeo',
    'jim jordan',
    'gaetz',
    'paul ryan',
    'kinzinger',
    'romney',
    'cruz',
    'murkowski',
    'trump',
    'boebert',
    'cheney',
    'susan collins',
    'mike pence',
    'conservatives',
    'republicans',
    'cawthorn',
    'mccarthy',
    'mcconnell',
    'marjorie taylor greene',
    'desantis',
    'gosar',
    'taylor greene',
    # topic stances (hypothesis text without the "The author of this text " prefix);
    # these keep their verb and trailing period, so they only match labels
    # from which neither has been stripped
    'supports voter IDs.',
    'does not believe that Israel should lift the blockade of Gaza.',
    'believes we should end affirmative action.',
    'believes we should adopt an austerity regime.',
    'believes we should adopt libertarianism.',
]

# Full hypothesis strings for the topic stances treated as liberal.
libtopics = [
    'The author of this text believes in the use of affirmative action.',
    'The author of this text does not believe that democratic governments should require voters to present photo identification at the polling station.',
    'The author of this text supports passing the American Jobs Act.',
    'The author of this text believes we should subsidize student loans.',
    'The author of this text supports limiting the right to bear arms.',
]

# Full hypothesis strings for the topic stances treated as conservative.
contopics = [
    'The author of this text supports voter IDs.',
    'The author of this text does not believe that Israel should lift the blockade of Gaza.',
    'The author of this text believes we should end affirmative action.',
    'The author of this text believes we should adopt an austerity regime.',
    'The author of this text believes we should adopt libertarianism.',
]

# All topic stances (hypothesis text without the "The author of this text "
# prefix): the five liberal ones first, then the five conservative ones.
topics = [
    'believes in the use of affirmative action.',
    'does not believe that democratic governments should require voters to present photo identification at the polling station.',
    'supports passing the American Jobs Act.',
    'believes we should subsidize student loans.',
    'supports limiting the right to bear arms.',
    'supports voter IDs.',
    'does not believe that Israel should lift the blockade of Gaza.',
    'believes we should end affirmative action.',
    'believes we should adopt an austerity regime.',
    'believes we should adopt libertarianism.',
]

# Subset to political data. reset_index() returns a fresh frame, so the column
# assignments below do not write into a slice of `docs`.
poldf = docs[docs['hypothesis'].isin(pol)].reset_index(drop=True)
# Premise / hypothesis pairs, one dict per retained document.
pol_dict = [
    {'text': premise, 'text_pair': hypothesis}
    for premise, hypothesis in zip(poldf['premise'], poldf['augmented_hypothesis'])
]

# Clean labels: strip the hypothesis template and the periods so only the
# target (or topic phrase) remains. Raw strings keep the regex escapes valid.
poldf['target'] = poldf['hypothesis'].replace(
    r'The author of this text is neutral towards |The author of this text supports |The author of this text opposes |The author of this text ',
    '',
    regex=True,
)
poldf['target'] = poldf['target'].replace(r'\.', '', regex=True)

# Binary conservative variable.
# Person/party targets are matched on the cleaned label. Topic hypotheses are
# matched on the raw hypothesis instead: cleaning strips their verb prefix and
# period, so they can never equal the topic entries stored in `con`.
poldf['con'] = (
    poldf['target'].isin(con) | poldf['hypothesis'].isin(contopics)
).astype(int)

# Calculate bias metrics: one F1 per (model, target), after dropping targets
# with fewer than 30 documents.
poldf = drop_rare(poldf, column='target', min_count=30)
# Split the frame once; sort=False keeps targets in order of first appearance.
docs_by_target = list(poldf.groupby('target', sort=False))
pol_res_dicts = []
for model in models:
    model_f1 = {}
    for target, target_docs in docs_by_target:
        model_f1[target] = f1_score(target_docs['entailment'], target_docs[model])
    pol_res_dicts.append(model_f1)
pol_res_df = pd.DataFrame(pol_res_dicts, index=pd.Index(models, name='model'))

# NOTE: no random_state is passed, so the bootstrapped std_error column varies
# slightly from run to run.
pol_bias = bias_metrics(pol_res_df*100).round(2)

pol_string = pol_bias.to_string()

output_file = './tables/table_14.txt'
with open(output_file, 'w', encoding='utf-8') as f:
    f.write(pol_string)